package cn.wzl.perceptron.perceptron;

import cn.wzl.perceptron.utils.RandomUtil;
import java.util.Objects;

/**
 * A single-node perceptron: the weighted sum of the inputs plus a bias, passed through an
 * activation function supplied by the subclass.
 *
 * <p>Each {@link #training} call does one forward pass followed by one parameter update:
 * every weight {@code w[i]} is increased by {@code eta * x[i] * dActive(loss(expected, actual))}
 * and the bias by {@code eta * dActive(loss(expected, actual))}. Because the update is
 * <em>added</em> to the parameters, subclasses presumably return from {@link #loss} a value whose
 * sign points in the error-reducing direction (e.g. {@code expected - actual}) — NOTE(review):
 * confirm against the concrete subclasses.
 *
 * <p>Instances hold mutable state (weights, bias, last input, last weighted sum) that both
 * {@link #training} and {@link #run} update without any synchronization, so an instance should
 * not be shared between threads without external locking.
 */
public abstract class SingleNodePerceptron {

    /** Input weights, one per input source; initialized from {@code RandomUtil.randomDouble()}. */
    protected double[] w;
    /** Per-weight deltas applied during the most recent {@link #training} call. */
    protected double[] dw;

    /** Bias term; initialized from {@code RandomUtil.randomDouble()}. */
    protected double b;
    /** Input array from the most recent forward pass (stored by reference, not copied). */
    protected double[] x;
    /** Learning rate that scales every weight and bias update. */
    protected double eta;

    /** Pre-activation value {@code sum(w[i] * x[i]) + b} from the most recent forward pass. */
    protected double weightedValue;

    /** Number of inputs consumed per sample; equals {@code w.length}. */
    protected int inputSourceCount;

    /**
     * Creates a perceptron with {@code inputSourceCount} inputs. Each weight, then the bias, is
     * initialized with a call to {@code RandomUtil.randomDouble()}.
     *
     * @param inputSourceCount number of inputs per sample; must be non-negative
     * @param eta learning rate used by {@link #training}
     * @throws IllegalArgumentException if {@code inputSourceCount} is negative
     */
    public SingleNodePerceptron(int inputSourceCount, double eta) {
        if (inputSourceCount < 0) {
            throw new IllegalArgumentException(
                    "inputSourceCount must be non-negative: " + inputSourceCount);
        }
        this.w = new double[inputSourceCount];
        this.dw = new double[inputSourceCount];
        for (int i = 0; i < inputSourceCount; i++) {
            w[i] = RandomUtil.randomDouble();
        }
        this.b = RandomUtil.randomDouble();
        this.eta = eta;
        this.inputSourceCount = inputSourceCount;
    }

    /**
     * Runs one training step on a single sample: a forward pass, then one update of the weights
     * and bias.
     *
     * @param x input values; must contain at least {@code inputSourceCount} elements (any extra
     *     trailing elements are ignored)
     * @param expectedValue the target output for {@code x}
     * @return the output computed by the forward pass, i.e. before this step's update
     * @throws NullPointerException if {@code x} is null
     * @throws IllegalArgumentException if {@code x} has fewer than {@code inputSourceCount}
     *     elements
     */
    public double training(double[] x, double expectedValue) {
        // forwardPropagation records x in this.x, which backPropagation then reads.
        double actualValue = forwardPropagation(x);
        backPropagation(expectedValue, actualValue);
        return actualValue;
    }

    /**
     * Computes the output for {@code x} without updating the weights or bias. This does still
     * overwrite the recorded {@link #x} and {@link #weightedValue}.
     *
     * @param x input values; must contain at least {@code inputSourceCount} elements
     * @return {@code active(sum(w[i] * x[i]) + b)}
     * @throws NullPointerException if {@code x} is null
     * @throws IllegalArgumentException if {@code x} has fewer than {@code inputSourceCount}
     *     elements
     */
    public double run(double[] x) {
        return forwardPropagation(x);
    }

    /**
     * Computes the weighted sum plus bias and applies the activation. It also records the input
     * and the pre-activation value for use by the backward pass and by subclasses.
     * The input is validated before any state is modified.
     */
    private double forwardPropagation(double[] x) {
        Objects.requireNonNull(x, "x");
        if (x.length < inputSourceCount) {
            throw new IllegalArgumentException(
                    "expected at least " + inputSourceCount + " inputs but got " + x.length);
        }
        this.x = x;
        double sum = 0;
        for (int i = 0; i < inputSourceCount; i++) {
            sum += w[i] * x[i];
        }
        // Bias is added after the dot product, matching the original summation order.
        this.weightedValue = sum + b;
        return active(weightedValue);
    }

    /**
     * Applies one update to the weights and bias from the error between the expected and actual
     * outputs. Reads the input recorded by the preceding forward pass.
     */
    private void backPropagation(double expectedValue, double actualValue) {
        double dy = loss(expectedValue, actualValue);
        // The activation-derivative term is the same for every weight and the bias,
        // so evaluate it once instead of once per weight.
        double delta = dActive(dy);
        for (int i = 0; i < inputSourceCount; i++) {
            dw[i] = eta * x[i] * delta;
            w[i] += dw[i];
        }
        b += eta * delta;
    }

    /**
     * Returns the error signal for one sample. The value is passed to {@link #dActive}, and the
     * result is added to the parameters, so its sign sets the update direction.
     *
     * @param expectedValue target output
     * @param actualValue output produced by the forward pass
     * @return the error signal (presumably {@code expectedValue - actualValue} or similar —
     *     NOTE(review): confirm against subclasses)
     */
    protected abstract double loss(double expectedValue, double actualValue);

    /**
     * Maps the error signal from {@link #loss} to the per-unit update multiplier (before scaling
     * by {@code eta} and the input). This receives the loss, not the pre-activation value.
     * An implementation that needs the activation's derivative at the current point can read
     * {@link #weightedValue}, which the forward pass has just set.
     *
     * @param dy the value returned by {@link #loss}
     * @return the update multiplier applied to every weight and to the bias
     */
    protected abstract double dActive(double dy);

    /**
     * Activation function applied to the weighted sum.
     *
     * @param weightedValue {@code sum(w[i] * x[i]) + b}
     * @return the perceptron's output
     */
    protected abstract double active(double weightedValue);
}
